using Flux,
    DiffEqFlux,
    DifferentialEquations,
    XLSX,
    DataFrames,
    DiffEqSensitivity,
    Plots, Optim, OrdinaryDiffEq

using Random, JLD

# Read rows 1-101 of columns A-F from the workbook. Row 1 is discarded below
# (presumably the header row — TODO confirm against the workbook).
df = DataFrame(
    XLSX.readdata(
        "/Users/urielyang/OneDrive - Emory University/Honors/Data_SK.xlsx",
        "Sheet!A1:F101",
    ),
)
# Keep only columns 5 and 6 (E: infected, F: recovered), skipping row 1.
final_data = df[2:end, [5, 6]]
# Observed series consumed by the loss function and the plots.
x = convert(Vector{Float64}, final_data[:, 1])  # infected
y = convert(Vector{Float64}, final_data[:, 2])  # recovered

# Quarantine-strength network: 3 inputs (the i, r, T states) → 20 (relu) →
# 20 (relu) → 1 output passed through sigmoid, so the strength lies in (0, 1).
nn = FastChain(
    FastDense(3, 20, relu),
    FastDense(20, 20, relu),
    FastDense(20, 1, sigmoid),
)
# Small initial parameters, uniform in [0, 1e-4), seeded for reproducibility.
# NOTE(review): by my count the network above has 3*20+20 + 20*20+20 + 20*1+1
# = 521 parameters, so some of the 541 values drawn here are never read by
# `nn` — TODO confirm the intended length.
Random.seed!(1)
p_init = rand(541) ./ 10_000
# Problem setup: right-hand side of the SIR model with a quarantine compartment.
"""
    SIRQ!(du, u, p, t)

In-place right-hand side of the SIR-Q model.

The state is `u = [s, i, r, T]` (susceptible, infected, recovered, quarantined).
`p[1]` is the infection rate β and `p[2]` the recovery rate γ. The quarantine
strength `q = nn(u[2:4], p)[1]` moves infected individuals into `T`:

    ds/dt = -β s i / n
    di/dt =  β s i / n - (γ + q) i
    dr/dt =  γ i
    dT/dt =  q i

where `n = s + i + r + T`. `t` is unused because the system is autonomous.
"""
function SIRQ!(du, u, p, t)
    s, i, r, T = u
    n = s + i + r + T
    β = p[1]
    γ = p[2]
    # NOTE(review): the network is evaluated on the FULL parameter vector `p`,
    # so β and γ also act as its first weights, and `p[3:end]` (apparently meant
    # to be the network parameters) is never used on its own. The post-training
    # `nn(prediction[2:4, :], ...)` call uses the same full-vector convention, so
    # switch both to `p[3:end]` together (and retrain) if that was the intent.
    # Evaluate the network once per call rather than once per equation that
    # needs q; this halves the forward passes and the AD work.
    q = nn(u[2:4], p)[1]
    infection = β * s * i / n
    du[1] = -infection
    du[2] = infection - (γ + q) * i
    du[3] = γ * i
    du[4] = q * i
    return nothing
end

# Initial state [s, i, r, T]: 60 million susceptible, 500 infected,
# 10 recovered and 10 quarantined.
u0 = [60e6, 500.0, 10.0, 10.0]
# Parameter layout read by SIRQ!: p[1] = β (initial guess 1), p[2] = γ
# (initial guess 0.025), followed by the small random values in p_init.
p = vcat([1.0, 0.025], p_init)
# Integrate over days 0 through 99, one point per row kept from the spreadsheet.
ts, tend = 0.0, 99.0
tspan = (ts, tend)
prob = ODEProblem(SIRQ!, u0, tspan, p)

"""
    predict(p)

Solve `prob` with `Tsit5` using parameters `p` and the global initial state `u0`.
The solution is saved at every day from `ts` to `tend`.

Return the solution as a matrix: rows are the state variables (s, i, r, T) and
columns are the saved times. If the solver stops early, fewer columns are
returned.
"""
function predict(p)
    save_times = ts:1:tend
    solution = solve(prob, Tsit5(); u0 = u0, p = p, saveat = save_times)
    return Array(solution)
end

# Run the solver once with the initial parameters as a smoke test.
temp = predict(p)

# Loss: sum of squared log-errors of the predicted infected (state 2) and
# recovered (state 3) trajectories against the observed series `x` and `y`.
loss_fun = function (p)
    # Record the most recently evaluated parameters in a global.
    global cur_p = p
    prediction = predict(p)
    # Only the first `n_fit` daily points (days 0-79) enter the loss; the
    # remaining points (days 80-99) are not fitted.
    n_fit = 80
    # If the solver stopped early (e.g. the trajectory became unstable), the
    # solution has fewer than `n_fit` columns and the slices below would throw a
    # BoundsError. Report an infinite loss instead so the optimizer rejects
    # this parameter set.
    size(prediction, 2) < n_fit && return Inf
    infected_err = log.(prediction[2, 1:n_fit]) .- log.(x[1:n_fit])
    recovered_err = log.(prediction[3, 1:n_fit]) .- log.(y[1:n_fit])
    return sum(abs2, infected_err) + sum(abs2, recovered_err)
end
loss_fun(p)


# Fit β, γ and the network weights jointly, starting from `p`: BFGS (from Optim)
# with an initial step norm of 0.02, capped at 100 iterations.
bfgs_opt = BFGS(; initial_stepnorm = 0.02)
res = DiffEqFlux.sciml_train(loss_fun, p, bfgs_opt; maxiters = 100)

# Plot the observed data against the fitted trajectories, plot the learned
# quarantine strength, and save both to disk.
t_step = ts:1:tend
t_step_1 = 0.0:1:39.0  # NOTE(review): not used below

# Use the parameters returned by the optimizer. `cur_p` only holds the LAST
# parameters evaluated by `loss_fun`, which need not be the optimum (for
# example, a line-search trial point).
p_opt = res.minimizer

fit_plot = scatter(t_step, x, color = [1], label = "infected")
scatter!(fit_plot, t_step, y, color = [2], label = "recovered")

prediction = predict(p_opt)

# `legend = false` hides the legend, so the labels above are currently not shown.
plot!(fit_plot, t_step, prediction[2, :], color = [3], label = "predicted infected", legend = false)
plot!(fit_plot, t_step, prediction[3, :], color = [4], label = "predicted recovered")
# Call display(fit_plot) to show the figure in a non-interactive session.

# Quarantine strength along the fitted trajectory. The whole parameter vector is
# passed to `nn`, matching how SIRQ! calls the network.
q_SK = nn(prediction[2:4, :], p_opt)
q_plot = scatter(t_step, q_SK[1, :], ylims = (0, 1.5), legend = false)

save("/Users/urielyang/OneDrive - Emory University/Honors/predres_SK.jld", "pred_SK", prediction, "q_SK", q_SK)
